[Common] Split cast/gated kernels by scaling mode #2248
Conversation
Pull Request Overview
This pull request refactors the large cast_kernels.cuh and cast_gated_kernels.cuh files into smaller, more organized header files structured by scaling mode. This improves code maintainability, readability, and navigation by creating specialized headers for different quantization and scaling implementations.
- Breaks down monolithic headers into focused, scaling-mode-specific files
- Reorganizes code structure without modifying functionality or behavior
- Creates dispatcher files to coordinate between different scaling implementations (see the sketch below)
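To make the dispatcher idea concrete, here is a minimal, self-contained sketch of the pattern. The type and function names below are illustrative stand-ins, not the exact signatures introduced by this PR:

```cpp
#include <stdexcept>

enum class ScalingMode { DelayedTensorScaling, MXFP8_1D, NVFP4_1D };

struct Tensor {
  ScalingMode scaling_mode;
  // data, scales, amax, ...
};

namespace fp8   { inline void quantize(const Tensor &, Tensor &) { /* FP8 kernels */ } }
namespace mxfp8 { inline void quantize(const Tensor &, Tensor &) { /* MXFP8 kernels */ } }
namespace nvfp4 { inline void quantize(const Tensor &, Tensor &) { /* NVFP4 kernels */ } }

namespace dispatch {
// One thin entry point per operation; each scaling mode keeps its kernels in
// its own header, and only the dispatcher knows about all of them.
inline void quantize_fwd_helper(const Tensor &input, Tensor &output) {
  switch (output.scaling_mode) {
    case ScalingMode::DelayedTensorScaling: fp8::quantize(input, output); break;
    case ScalingMode::MXFP8_1D:             mxfp8::quantize(input, output); break;
    case ScalingMode::NVFP4_1D:             nvfp4::quantize(input, output); break;
    default: throw std::runtime_error("Unsupported scaling mode");
  }
}
}  // namespace dispatch
```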
Reviewed Changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| transformer_engine/common/util/cast_kernels.cuh | Removed all content - entire file deleted as part of refactoring |
| transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh | NVFP4 quantize with transpose functionality, updated file path and namespacing |
| transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh | New file containing NVFP4-specific quantization kernels |
| transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh | New file containing NVFP4 dequantization functionality |
| transformer_engine/common/cast/nvfp4/core_nvfp4.cuh | New file with core NVFP4 utility functions and device operations |
| transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | New file containing MXFP8 quantization kernels |
| transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh | MXFP8 gated operations, significantly reduced from original gated kernels file |
| transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh | New file containing MXFP8 dequantization functionality |
| transformer_engine/common/cast/fp8/quantize_fp8.cuh | New file containing FP8 quantization kernels |
| transformer_engine/common/cast/fp8/gated_fp8.cuh | New file containing FP8 gated operations |
| transformer_engine/common/cast/fp8/dequantize_fp8.cuh | New file containing FP8 dequantization functionality |
| transformer_engine/common/cast/dispatch/quantize.cuh | New dispatcher file coordinating quantization across scaling modes |
| transformer_engine/common/cast/dispatch/gated.cuh | New dispatcher file coordinating gated operations across scaling modes |
| transformer_engine/common/cast/dispatch/dequantize.cuh | New dispatcher file coordinating dequantization across scaling modes |
// This kernel supports only two scaling cases:
// 1. r16c0 - Rowwise NVFP4
// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &)>
Do we actually support fused activation-cast kernels for NVFP4? If not, we should remove these template arguments so that we don't compile unnecessary kernels and so we prevent users from accidentally calling them. We should also remove them from the kernel, and modify quantize_helper so it errors out if you attempt something invalid.
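A minimal sketch of the kind of guard being suggested, assuming a dispatcher-level check; the names and the error mechanism here are illustrative, not the actual quantize_helper code:

```cpp
#include <stdexcept>

enum class ScalingMode { DelayedTensorScaling, MXFP8_1D, NVFP4_1D };

// Sketch: reject fused activation-cast requests for NVFP4 at dispatch time
// instead of compiling (and silently exposing) kernels that are not supported.
template <bool IS_ACT>
void quantize_helper_sketch(ScalingMode mode) {
  if constexpr (IS_ACT) {
    if (mode == ScalingMode::NVFP4_1D) {
      throw std::runtime_error("Fused activation + NVFP4 quantization is not supported");
    }
  }
  // ... dispatch to the scaling-mode-specific kernels ...
}
```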
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &)>
I intentionally left the activation template arguments and all the activation-related logic untouched, so we can easily enable it when/if it becomes part of the FP4 recipe.
@ptrendx, should we keep it, or should I just go ahead and clean up the kernel?
I also didn't want to add any functionality-related modifications to this PR, to avoid overwhelming it, and would rather do that separately in follow-up PRs, since there are some parts of the NVFP4 code that need to be reviewed/changed anyway.
If we don't support them, we should at least error out if you attempt to run them. Avoiding unnecessary compilations would also be useful so we don't blow up compile time and binary size.
I'm fine deferring this if we want this PR to minimize functional changes, but we should aim to catch more of these errors.
@timmoon10 @Oleg-Goncharov Let's minimize changes in this PR and just do the code movement here. Otherwise it will be very hard to properly review whether the functionality was altered.
timmoon10 left a comment
Overall LGTM once we iron out the test failures.
Force-pushed from 9202e6d to 86dd987
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The latest update applies formatting-only changes to transformer_engine/common/cast/nvfp4/core_nvfp4.cuh, aligning the file with the project's clang-format configuration (Google-based style, 100-char column limit, 2-space indentation). No functional or behavioral modifications were made—function signatures and error macros were reformatted to improve consistency and readability. This change ensures that the NVFP4 core utilities, which handle FP4 quantization and conversion operations via inline PTX assembly, adhere to the repository's established formatting standards.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/cast/nvfp4/core_nvfp4.cuh | 5/5 | Formatting-only changes: function signatures and error macros reformatted to match clang-format style; no functional modifications. |
Confidence score: 5/5
- This PR update is safe to merge with minimal risk, as it contains only formatting changes with no functional modifications.
- Score reflects that the changes are purely cosmetic (clang-format enforcement) and cannot introduce bugs, regressions, or behavioral changes; all function logic remains identical.
- No files require special attention; this is a straightforward formatting pass to ensure style consistency across the NVFP4 core utilities.
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. This update addresses formatting inconsistencies in core_nvfp4.cuh by reformatting function signatures to comply with the project's .clang-format style guide (Google-based, 100-character column limit). The changes are purely cosmetic—multi-line function signatures like compute_decoding_scaling_factor and mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding are now consistently split across lines with proper indentation, while compute_global_encode_scaling_factor_FP4 is collapsed to a single line. Additionally, the #else branches that threw errors when FP4_TYPE_SUPPORTED is undefined have been removed, simplifying the code structure. This refactoring is part of the broader PR goal to split large cast kernel headers into smaller, more maintainable files organized by scaling mode. The reformatting improves readability and navigation within device code, aligning with the project's style enforcement strategy (cpplint, clang-format, pre-commit hooks).
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/cast/nvfp4/core_nvfp4.cuh | 5/5 | Formatting-only changes: function signatures reformatted for readability and #else error branches removed. |
Confidence score: 5/5
- This PR is safe to merge with minimal risk as the changes are purely cosmetic and do not modify any logic.
- Score reflects formatting-only changes with no impact on compiled code or behavior; all modifications align with project style guidelines.
- No files require special attention; this is a straightforward formatting cleanup.
1 file reviewed, no comments
/te-ci
}

const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && (cols % 32 == 0) &&
                             is_supported_by_CC_100();
I also really don't like the name of this function, but let's fix that in the
NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ",
           "by 32, got input of shape ", input.data.shape);
This error is potentially misleading, since it is possible that the reason for use_tma_kernels being false is that the rowwise or columnwise output is missing. Also, this is actually wrong, since I believe the mxfp8 kernels can support only one of those outputs, right? Looking at it a second time, it seems that the logic for setting the rowwise and columnwise output variables is convoluted and not at all understandable. I will make a comment there.
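One possible shape for clearer diagnostics, as a sketch: the predicate names mirror the snippet quoted above, while the reporting structure is an assumption for illustration, not the PR's code.

```cpp
#include <cstddef>
#include <string>

// Sketch: report the specific requirement that failed instead of a single
// "divisible by 32" message. Only the predicate names are taken from the PR.
std::string why_tma_kernels_unavailable(bool is_fp8_rowwise_output, bool is_fp8_colwise_output,
                                        std::size_t cols, bool is_supported_by_CC_100) {
  if (!is_supported_by_CC_100) return "unsupported compute capability";
  if (!is_fp8_rowwise_output) return "rowwise FP8 output is missing";
  if (!is_fp8_colwise_output) return "colwise FP8 output is missing";
  if (cols % 32 != 0)
    return "last dimension (" + std::to_string(cols) + ") is not divisible by 32";
  return "";  // empty: the TMA kernels can be used
}
```

Reporting the first failed requirement avoids blaming the tensor shape when the real cause is a missing rowwise or colwise output.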
namespace detail {
using Empty = transformer_engine::Empty;
__device__ inline float identity(float value, const Empty &) { return value; }
}  // namespace detail
Do we need this?
Yes, it is currently used by the CastVectorizedUnaryKernelLauncher and CastVectorizedUnaryGradKernelLauncher below
Force-pushed from c0f4a1e to b764dea
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR.
This update adds a new return_max_logit feature to fused attention APIs across PyTorch, JAX, and C++ layers, while completing the cast kernel refactoring. The attention API now optionally returns the maximum attention logit per head (useful for Muon optimizer integration and numerical stability analysis). The change threads a boolean parameter through all attention forward paths and disables FP8/F16_max512 backends when return_max_logit=true since only F16_arbitrary_seqlen supports it. Additionally, the nvidia-mathdx dependency was removed from build configurations, sigmoidf was added to math.h, submodule pointers were updated, and several utility refactors (RNG API simplification, tensor creation helpers) were applied to align with the reorganized kernel structure. The bulk of the cast kernel split (FP8, MXFP8, NVFP4 into separate headers) was completed in earlier commits and is not repeated here.
PR Description Notes:
- The PR description states "No functional or behavior changes: code is moved, not modified," but this review includes the return_max_logit feature addition, which is a functional change. The description should be updated to reflect this.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/fused_attn/fused_attn.cpp | 4/5 | Adds return_max_logit parameter to C API; disables FP8/F16_max512 backends when enabled |
| transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 4/5 | Implements separate Max/Sum_Exp tensor allocation when return_max_logit=true |
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | 2/5 | Adds max_logit support but has inconsistent FP8 handling and unverified unpacking logic |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 3/5 | Integrates max_logit with context-parallel attention; requires careful None-checking across all code paths |
| transformer_engine/jax/csrc/extensions/attention.cpp | 2/5 | Critical bug: hardcodes false for deterministic parameter instead of propagating actual value |
| pyproject.toml | 3/5 | Removes nvidia-mathdx build dependency without documentation; may break builds if code depends on it |
| transformer_engine/common/util/curanddx.hpp | 4.5/5 | New Philox4x32 RNG implementation extracted during refactor; self-contained and correct |
| transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh | 3/5 | Potential boundary check error at line 128; macro instead of constexpr at line 49 |
| transformer_engine/common/cast/fp8/quantize_fp8.cuh | 3/5 | Incorrect comments claiming elt is 0 when no activation applied (lines 168, 178) |
| transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 3.5/5 | Uninitialized scaling_type variable if neither rowwise nor colwise scaling enabled |
Confidence score: 3/5
- This PR introduces functional changes (max_logit feature) beyond the stated "code is moved, not modified" scope, and includes critical bugs in the JAX attention extension that will silently ignore user-specified deterministic behavior.
- Score lowered due to: (1) JAX extension bug where the deterministic parameter is hardcoded to false, (2) inconsistent FP8 handling in PyTorch backends where max_logit is initialized but never populated, (3) unverified tuple unpacking logic that assumes the fused_attn_fwd return structure without clear guarantees, (4) potential uninitialized variables in the MXFP8/NVFP4 kernels, and (5) missing documentation for the nvidia-mathdx removal.
- Pay close attention to transformer_engine/jax/csrc/extensions/attention.cpp (deterministic parameter bug), transformer_engine/pytorch/attention/dot_product_attention/backends.py (FP8 max_logit handling), transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py (None-checking for max_logit operations), and the NVFP4/MXFP8 quantization kernels with potential uninitialized variables.
Additional Comments (17)
- transformer_engine/common/util/curanddx.hpp, line 19-23 (link) style: pointer parameter alignment doesn't match .clang-format (should be unsigned int* → unsigned int *)
- setup.py, line 165-174 (link) style: assertion will fail if a submodule is at a different commit than expected (starts with '+'); consider handling this case explicitly or documenting this behavior more clearly in the error message
- setup.py, line 180-181 (link) style: silently returns on any subprocess error; consider logging the exception or at least distinguishing between expected failures (not a git repo) vs. unexpected errors
- transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh, line 690-697 (link) logic: scaling_type is used uninitialized if neither rowwise nor columnwise scaling is enabled. Add an else branch with NVTE_CHECK or default initialization.
- transformer_engine/pytorch/cpp_extensions/fused_attn.py, line 331 (link) style: amax_dims logic assumes thd vs bshd/sbhd are the only possibilities; if new layouts are added, this could silently produce an incorrect max_logit shape
- transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 697 (link) logic: the unpacking operator *max_logit expects zero or one element, but the return_max_logit flag determines whether fused_attn_fwd returns it; check that the unpacking consistently handles both cases. Does fused_attn_fwd always return a tuple with *max_logit in the correct position when return_max_logit=True, and does it skip that return value entirely when return_max_logit=False?
- transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 1164-1165 (link) logic: initializes max_logit_per_step and max_logit to None but later indexes them; verify that indexing/assignment only happens after return_max_logit=True initialization
- transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 1254-1257 (link) logic: the list comprehension creates max_logit_per_step tensors only when return_max_logit=True and non-FP8; ensure all subsequent accesses to max_logit_per_step[i] guard on the same condition
- transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 1619-1623 (link) logic: computes max_logit via torch.clone and torch.maximum only when return_max_logit=True; confirm that max_logit_per_step is never None at these indices when the condition is met
- transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 1629-1632 (link) logic: all-reduces max_logit only when return_max_logit=True; ensure that max_logit is not None at this point
- transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 2754 (link) logic: unpacking *max_logit_ from fused_attn_fwd requires a consistent return tuple structure when return_max_logit=True. Does fused_attn_fwd guarantee that it returns a tuple with the max_logit element appended when return_max_logit=True, and omits it otherwise?
- transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 2776-2777 (link) logic: assigns max_logit_per_step[i] from max_logit_[0] only when return_max_logit=True; ensure max_logit_ is not empty
- transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 2812-2821 (link) logic: clones and computes the maximum of max_logit_per_step only when return_max_logit=True; verify that all max_logit_per_step entries are initialized tensors, not None
- transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 3261 (link) logic: unpacking *max_logit from fused_attn_fwd requires a consistent return tuple structure when return_max_logit=True
- transformer_engine/pytorch/attention/dot_product_attention/backends.py, line 1146 (link) logic: max_logit is initialized as None but only populated for the non-FP8 path. The FP8 path (lines 1174-1233) doesn't compute max_logit, leading to inconsistent behavior. Should the FP8 path also return max_logit when requested, or is this intentionally unsupported?
- transformer_engine/pytorch/attention/dot_product_attention/backends.py, line 1250 (link) logic: unpacking *max_logit from fused_attn_fwd assumes an optional return. If fused_attn_fwd doesn't return max_logit, this unpacking will fail. Does fused_attn_fwd always return max_logit as an optional value, or only when return_max_logit=True?
- transformer_engine/pytorch/attention/dot_product_attention/backends.py, line 1936-1938 (link) style: conditional return logic is duplicated across UnfusedDotProductAttention and FusedAttention. Consider extracting it to a helper function to reduce duplication.
49 files reviewed, 37 comments
output->data.dtype, OType,

constexpr int nvec = 32 / sizeof(OType);
DequantizeParam p; p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
style: Multiple statements on one line. Split initialization and assignment for clarity:
Suggested change:
DequantizeParam p;
p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
Yup.
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p,
                           cudaStream_t stream) {
using namespace dispatch;
style: redundant using namespace dispatch inside the dispatch namespace
Agreed.
          float (*DActOP)(float, const ParamOP &)>
void quantize_dgated_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input,
                            NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) {
using namespace dispatch;
style: redundant using namespace dispatch inside the dispatch namespace
Agreed.
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
 workspace, nullptr, stream);
style: indentation of the continuation line changed from 4 to 1 space, which deviates from the project's .clang-format (ContinuationIndentWidth: 4)
output_vec[my_output_index + i] = out;
}
}
#endif  // CUDA_VERSION
syntax: Comment should be FP4_TYPE_SUPPORTED to match the #if on line 32, not CUDA_VERSION
Agreed.
partial_dbias += elt;
}
} else {
// If no activation, elt is 0 so we can safely do this
logic: comment says 'elt is 0' but this is incorrect – elt is only 0 if out_of_bounds is true; for in-bounds elements elt equals the loaded value. The logic relies on conditional accumulation in the IS_DACT branch but assumes unconditional safety here. Is the assumption that out-of-bounds elements are loaded as zero guaranteed by the TMA copy operation?
amax = fmaxf(amax, fabsf(elt));
}
} else {
// If no activation, elt is 0 so we can safely do this
logic: same issue as line 168 – comment incorrectly states 'elt is 0' but elt is only guaranteed zero for out-of-bounds accesses. For in-bounds elements without activation, elt is the loaded input value.
 * \brief Common functions in quantize.
 */

#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
syntax: header guard mismatch – file is cast/core/common.cuh but guard is QUANTIZE_CORE_COMMON_CUH_
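For reference, a guard that matches the new location could look like the following; the exact macro name is the authors' choice, this is only an example of keeping the guard and the path in sync:

```cpp
// transformer_engine/common/cast/core/common.cuh
#ifndef TRANSFORMER_ENGINE_CAST_CORE_COMMON_CUH_
#define TRANSFORMER_ENGINE_CAST_CORE_COMMON_CUH_

// ... declarations ...

#endif  // TRANSFORMER_ENGINE_CAST_CORE_COMMON_CUH_
```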
 ************************************************************************/

/*! \file common.cuh
 * \brief Common functions in quantize.
style: file comment says "in quantize" but this lives in cast/core/ – update to "Common functions in cast." or similar
template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
style: template parameter IType name suggests input type, but it is used as OType (output) inside reduce_dbias_kernel; consider renaming to OType for clarity
Force-pushed from dc04670 to 703556c
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The latest commit addresses formatting inconsistencies introduced during the previous code reorganization. Seven files received minor formatting adjustments: math.h condensed a one-line sigmoidf function to match the project's .clang-format rule allowing short functions on a single line; dispatch/quantize.cuh split a long boolean condition across two lines; dispatch/gated.cuh, gated_mxfp8.cuh, and core_nvfp4.cuh adjusted line breaks and continuation indentation in function signatures and error messages; quantize_transpose_nvfp4.cuh reformatted two PTX function calls; and gated_fp8.cuh received both formatting changes (line-continuation indentation) and a critical functional modification (removal of null-pointer validation checks for scale_inv and scale tensors, plus a macro change that broadens permitted output types beyond FP8). These changes ensure consistency with the project's 100-character column limit and continuation-indent rules as part of the broader refactoring to split monolithic cast kernel headers by scaling mode.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/util/math.h | 5/5 | Reformatted sigmoidf from multi-line to single-line (pure formatting) |
| transformer_engine/common/cast/dispatch/quantize.cuh | 5/5 | Split long boolean condition across two lines (formatting only) |
| transformer_engine/common/cast/dispatch/gated.cuh | 5/5 | Reformatted error messages and function calls for readability (no logic changes) |
| transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh | 5/5 | Reformatted function signatures to break parameters across lines (cosmetic only) |
| transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh | 4/5 | Reformatted PTX calls but introduced 1-space continuation indent (violates .clang-format) |
| transformer_engine/common/cast/fp8/gated_fp8.cuh | 2/5 | Removed null-pointer checks for scale tensors and changed type-switch macro, risking null dereferences and incorrect behavior for non-FP8 outputs |
| transformer_engine/common/cast/nvfp4/core_nvfp4.cuh | 4/5 | Reordered includes and introduced nested namespaces; removed trailing whitespace; preserved commented-out code for numerical matching |
Confidence score: 2/5
- This PR introduces critical safety regressions in gated_fp8.cuh that could cause null-pointer dereferences and incorrect scaling behavior for non-FP8 outputs, despite being labeled a "no functional changes" refactoring
- Score reflects a dangerous removal of validation checks (lines 277–280 in gated_fp8.cuh) and a macro change that may silently break FP8-specific scaling logic for other data types; additional style violations in quantize_transpose_nvfp4.cuh (1-space vs. 4-space continuation indent) and include-order changes in core_nvfp4.cuh that could affect compilation
- Pay extremely close attention to transformer_engine/common/cast/fp8/gated_fp8.cuh (removed null checks, macro change) and verify that all callers guarantee scale-tensor allocation before invoking cast_gated_tma; review quantize_transpose_nvfp4.cuh and core_nvfp4.cuh for style compliance
7 files reviewed, 5 comments
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p,
                           cudaStream_t stream) {
using namespace dispatch;
Agreed.
          float (*DActOP)(float, const ParamOP &)>
void quantize_dgated_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input,
                            NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) {
using namespace dispatch;
Agreed.
if constexpr (IS_DBIAS || IS_DACT) {
  // backward - input is incoming gradient
  input_tensor = convertNVTETensorCheck(grad);
  activation_input_tensor = convertNVTETensor(input);
} else {
  // forward = input is activation input
  input_tensor = convertNVTETensorCheck(input);
  activation_input_tensor = nullptr;
}
To be honest I really hoped that as part of this refactor we could fix this part of this API since it is super confusing and actually have consistent naming of the function arguments.
I refactored dispatch/quantize.cuh in a similar way to dispatch/gated.cuh, splitting it into two helpers (FWD and BWD). There is a bunch of duplicated code, but the logic should now be more intuitive.
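For readers following along, a rough sketch of what such a FWD/BWD split looks like at the signature level. These are stand-in types and parameter lists; the real helpers operate on NVTETensor handles and a CUDA stream:

```cpp
// Stand-in types only; the real helpers live in cast/dispatch/quantize.cuh.
struct Tensor {};

namespace dispatch {

// Forward: plain (optionally activation-fused) quantization; no grad, dbias,
// or workspace parameters to thread through.
template <bool IS_ACT>
void quantize_fwd_helper(const Tensor &input, Tensor &output);

// Backward: the incoming gradient is the tensor being quantized; the forward
// activation input is only needed for the dAct/dbias fusions.
template <bool IS_DBIAS, bool IS_DACT>
void quantize_bwd_helper(const Tensor &grad, const Tensor &activation_input, Tensor &output,
                         Tensor &dbias, Tensor &workspace);

}  // namespace dispatch
```

The duplication buys call sites that read unambiguously as either forward or backward, mirroring the gated dispatcher.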
output->data.dtype, OType,

constexpr int nvec = 32 / sizeof(OType);
DequantizeParam p; p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
Yup.
__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax,
                                                                         const float S_enc) {
  // constexpr float rcp_6f = 1.0f / 6.0f;
  // const float S_dec_b = block_amax * rcp_6f;
  // const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b * S_enc);
  // return S_dec_b_fp8;
  // NOTE: Divide by 6.0f is not elegant and not efficient.
  // However, this is part of the emulation code to ensure exact match.
  using namespace detail;
  constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;  // 6.0f;
  const float S_dec_b = block_amax / fp4_max * S_enc;
  return static_cast<nvfp4_scale_t>(fminf(S_dec_b, TypeExtrema<float>::max));
}
Yup, it seems to be basically the same function written twice.
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
While the comment is wrong, it actually points to a correct fact: something is skippable here, namely the global_amax == 0 check. If the amax is 0 then the scale is infinity and would already be clamped by the fminf (and if the amax is 0 we do not really care what the value of the scale is, as long as it is finite, so multiplying by 0 would not produce NaN). This kernel is so tiny, though, that it doesn't matter.
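To make the argument concrete, here is a tiny standalone check of that reasoning. The FP8_max * FP4_max / amax form of the global encode scale is an assumption about the usual NVFP4 recipe; the point is only the clamping behavior under IEEE-754 arithmetic:

```cpp
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const float fp8_max = 448.0f, fp4_max = 6.0f;
  const float global_amax = 0.0f;
  // amax == 0 makes the raw scale +inf ...
  float scale = fp8_max * fp4_max / global_amax;
  // ... but the clamp already brings it back to a finite value, so the
  // explicit amax == 0 branch only changes an irrelevant scale value.
  scale = std::fmin(scale, std::numeric_limits<float>::max());
  std::printf("scale = %g, 0 * scale = %g\n", scale, 0.0f * scale);  // finite, 0
  return 0;
}
```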
output_vec[my_output_index + i] = out;
}
}
#endif  // CUDA_VERSION
Agreed.
const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t x = thread_idx % M;
const size_t y = thread_idx / M;
That may be correct, and I did intend to look into this kernel anyway since it doesn't work in some cases, but that is out of scope for this PR.
/te-ci
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The latest update addresses several issues raised in previous reviews: (1) dispatcher refactoring splits the monolithic quantize_helper into separate quantize_fwd_helper and quantize_bwd_helper functions, removing unused template parameters from forward paths; (2) parameter reordering in gated kernels moves gated_input/input before grad to establish a consistent convention across scaling modes (forward input, then gradient); (3) naming improvements rename helpers to quantize_gated_fwd_helper and quantize_gated_bwd_helper for clarity; (4) cosmetic fixes correct #endif comments, remove redundant using namespace dispatch declarations, and adjust formatting to match .clang-format conventions. All changes are code-movement or signature-reordering refactorings with no functional modifications—the dispatch logic, kernel launch sites, and validation checks remain identical to the previous version.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/util/math.h | 5/5 | Collapsed sigmoidf function from three lines to one line for consistency with other short device functions |
| transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh | 5/5 | Moved NVFP4 dequantization kernel to dedicated header; corrected #endif comment from CUDA_VERSION to FP4_TYPE_SUPPORTED |
| transformer_engine/common/cast/nvfp4/core_nvfp4.cuh | 5/5 | Reordered two #include directives and removed trailing whitespace; no logic changes |
| transformer_engine/common/activation/activation_template.h | 5/5 | Updated activation function templates to call renamed _fwd_helper and _bwd_helper dispatch functions with simplified template parameters |
| transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh | 5/5 | Swapped quantize_gated parameter order (gated_input before grad) to match other scaling modes; minor formatting adjustments |
| transformer_engine/common/cast/fp8/gated_fp8.cuh | 5/5 | Reordered function parameters in cast_gated_tma and cast_gated_bwd to place input before gradient for consistency |
| transformer_engine/common/cast/dispatch/gated.cuh | 4/5 | Renamed helpers to _fwd_helper and _bwd_helper; reordered arguments in fp8::cast_gated_tma and mxfp8::quantize_gated calls to pass input/gated_input before grad |
| transformer_engine/common/cast/cast.cu | 3.5/5 | Refactored quantization API entry points to use separate quantize_fwd_helper and quantize_bwd_helper dispatchers; reordered parameters (input, activation_input, output, ...) for backward helpers |
| transformer_engine/common/cast/dispatch/quantize.cuh | 0/5 | File listed but no summary provided; changes unknown |
Confidence score: 3.5/5
- This PR is mostly safe to merge but requires careful verification that parameter reordering matches updated kernel signatures across all call sites.
- Score lowered because (1) the parameter reordering in cast.cu for backward helpers (input, activation_input, output) may break call sites if upstream callers were not updated, (2) the argument-order change in gated kernel calls (fp8::cast_gated_tma, mxfp8::quantize_gated) must be validated against the kernel signatures to ensure input/gated_input is now expected before grad, and (3) the file transformer_engine/common/cast/dispatch/quantize.cuh has no summary provided, so its changes are unreviewed.
- Pay close attention to transformer_engine/common/cast/cast.cu (verify all callers of nvte_quantize_dbias* pass parameters in the new order) and transformer_engine/common/cast/dispatch/gated.cuh (confirm the fp8::cast_gated_tma and mxfp8::quantize_gated kernel signatures accept input/gated_input before grad).
9 files reviewed, 3 comments
const Tensor *dummy_input_tensor = nullptr;
Tensor *dummy_dbias_tensor = nullptr;
Tensor *dummy_workspace_tensor = nullptr;
style: dummy_input_tensor, dummy_dbias_tensor, and dummy_workspace_tensor are declared but never used in the FP8 forward path. Consider removing them or documenting why they're passed as null. Are these dummy pointers required by the kernel signature, or can the kernel signature be simplified for the forward case?
const Tensor *dummy_input_tensor = nullptr;
Tensor *dummy_dbias_tensor = nullptr;
Tensor *dummy_workspace_tensor = nullptr;
style: duplicate dummy variable declarations (same as lines 61–63). Consider factoring these into a shared helper or refactoring the kernel signature to avoid passing unused parameters
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(
    inputs[i], outputs[i], nullptr, detail::get_compute_stream(i % num_streams));
logic: In multi-tensor quantize, quant_configs parameter is declared on line 138 but never used. Should each call to quantize_fwd_helper pass quant_configs (or an indexed config)?
Agreed, we should pass the quant_config (quant_configs is actually misleading name since it suggests there are multiple of those) to those quantize functions.
Regarding the
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
Could we move those functions involving activations to the activation-specific files? That way we could make sure that we use fast math only for the activations (and then maybe actually turn it on by default?) and not for the entire cast.cu file.
Greptile Overview
Greptile Summary
This PR restructures cast and gated kernel code by splitting the monolithic cast_kernels.cuh and cast_gated_kernels.cuh files into a cleaner directory hierarchy organized by scaling mode (FP8, MXFP8, NVFP4). The refactoring introduces a dispatch layer that routes operations based on scaling_mode, with format-specific implementations in dedicated subdirectories.
Key Changes:
- Moved
util/cast.cutocast/cast.cuand created new directory structure:cast/{core,dispatch,fp8,mxfp8,nvfp4}/ - Deleted 2188-line
cast_kernels.cuh, replacing it with 10+ focused header files - Added dispatch layer (
dispatch/*.cuh) to route operations by scaling mode - Extracted common utilities to
core/common.cuh - Updated CMakeLists.txt and include paths accordingly
Issues Found:
- Critical: Removed null-pointer validation for scale_inv.dptr and scale.dptr in FP8 gated kernels (previously validated in original code)
scale_inv.dptrandscale.dptrin FP8 gated kernels (previously validated in original code) - Critical: Changed type switch macro from
FP8ONLYto genericOUTPUTin gated kernels, potentially allowing non-FP8 output types where FP8-specific scaling logic is expected - Missing bounds check in NVFP4 dequantization could cause out-of-bounds memory access
- Unused quant_configs parameter in multi-tensor quantize path
quant_configsparameter in multi-tensor quantize path - Several misleading comments and minor style issues
Confidence Score: 3/5
- This PR has moderate risk due to removed validation logic and type safety changes in FP8 gated kernels that could cause runtime errors
- Score reflects that while the refactoring structure is sound and most code is cleanly extracted, there are two critical issues in the FP8 gated kernel path: (1) removed null-pointer checks for scale tensors that could cause segfaults, and (2) broadened type constraints that may allow incorrect types through. The NVFP4 bounds check issue is also concerning. These are functional changes disguised as pure refactoring.
transformer_engine/common/cast/fp8/gated_fp8.cuh requires immediate attention for restored validation and type safety. transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh needs bounds checking.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/cast/cast.cu | 4/5 | Moved from util/cast.cu. Updated include paths to new cast/ directory structure. Entry point functions unchanged except for delegation to new dispatch layer. |
| transformer_engine/common/cast/dispatch/quantize.cuh | 3/5 | New dispatcher for quantization operations. Routes to FP8/MXFP8/NVFP4 implementations based on scaling mode. Contains unused quant_configs parameter in multi-tensor path (line 138). |
| transformer_engine/common/cast/fp8/quantize_fp8.cuh | 3/5 | Extracted FP8 quantization kernels from original cast_kernels.cuh. Contains misleading comments about out-of-bounds element handling (lines 168, 178). |
| transformer_engine/common/cast/fp8/gated_fp8.cuh | 2/5 | Extracted FP8 gated activation kernels. Removed null-pointer validation for scale_inv.dptr and scale.dptr that existed in original code. Changed macro from FP8ONLY to generic OUTPUT, potentially allowing non-FP8 types. |
| transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh | 3/5 | Extracted NVFP4 quantization kernels. Contains potential boundary check issue at line 128 comparing against wrong dimension. |
| transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh | 2/5 | Extracted NVFP4 dequantization kernels. Missing bounds check on thread_idx could cause out-of-bounds memory access (line 38). Incorrect comment reference to CUDA_VERSION instead of FP4_TYPE_SUPPORTED. |
Sequence Diagram
sequenceDiagram
participant API as nvte_quantize APIs<br/>(cast.cu)
participant Dispatch as Dispatch Layer<br/>(dispatch/*.cuh)
participant FP8 as FP8 Kernels<br/>(fp8/*.cuh)
participant MXFP8 as MXFP8 Kernels<br/>(mxfp8/*.cuh)
participant NVFP4 as NVFP4 Kernels<br/>(nvfp4/*.cuh)
participant Core as Core Utilities<br/>(core/common.cuh)
Note over API,Core: Refactoring: Split by Scaling Mode
API->>Dispatch: quantize_fwd_helper()
Dispatch->>Dispatch: Check output->scaling_mode
alt DELAYED_TENSOR_SCALING
Dispatch->>FP8: quantize<FP8>()
FP8->>Core: Use common helpers
FP8-->>Dispatch: FP8 quantized output
else MXFP8_1D_SCALING
Dispatch->>MXFP8: quantize<MXFP8>()
MXFP8->>Core: Use common helpers
MXFP8-->>Dispatch: MXFP8 quantized output
else NVFP4_1D_SCALING
Dispatch->>NVFP4: quantize_transpose()
NVFP4->>Core: Use common helpers
NVFP4-->>Dispatch: NVFP4 quantized output
end
Dispatch-->>API: Return quantized tensor
Note over API,NVFP4: Gated operations follow similar pattern
API->>Dispatch: quantize_gated_fwd_helper()
Dispatch->>Dispatch: Check scaling_mode
alt DELAYED_TENSOR_SCALING
Dispatch->>FP8: cast_gated_fwd()
else MXFP8_1D_SCALING
Dispatch->>MXFP8: quantize_gated()
end
20 files reviewed, no comments
|
/te-ci
Greptile Overview
Greptile Summary
This PR refactors the cast and activation kernel code for better organization and maintainability:
- Activation Functions Extracted: Moved GELU, ReLU, and SwiGLU activation functions from
cast.cu into separate dedicated files (gelu.cu, relu.cu, swiglu.cu) under common/activation/
cast_kernels.cuhandcast_gated_kernels.cuhinto smaller, focused headers organized by scaling mode (FP8, MXFP8, NVFP4) undercommon/cast/subdirectories - Dispatch Layer: Created new dispatch layer (
dispatch/quantize.cuh,dispatch/dequantize.cuh,dispatch/gated.cuh) that routes to appropriate scaling-mode-specific implementations - Fast-Math Option: Added
NVTE_BUILD_ACTIVATION_WITH_FAST_MATHCMake option (defaults to ON) to enable--use_fast_mathcompilation for activation kernels
The refactoring is purely structural - no functional changes to kernel logic or behavior. The new organization improves navigability and makes it easier to extend specific scaling paths in the future.
Confidence Score: 5/5
- This PR is safe to merge - it is a pure code refactoring with no functional changes
- This is a well-executed refactoring that only reorganizes existing code without modifying logic. The activation functions are simply moved to separate files with identical implementations using template dispatchers. The cast kernels are split by scaling mode into logical subdirectories. All previous comments from reviewers focus on pre-existing issues in the moved code (style, potential bugs), not issues introduced by this refactoring. The build configuration properly handles the new file structure and adds an opt-in fast-math flag for activations.
- No files require special attention - this is a straightforward refactoring
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/CMakeLists.txt | 5/5 | Added activation files (gelu.cu, relu.cu, swiglu.cu) to arch-specific sources list and enabled optional fast-math compilation for activation kernels via NVTE_BUILD_ACTIVATION_WITH_FAST_MATH flag (defaults to ON) |
| transformer_engine/common/activation/gelu.cu | 5/5 | New file containing GELU-related activation functions extracted from cast.cu - no logic changes, just code reorganization using activation_template.h dispatchers |
| transformer_engine/common/activation/relu.cu | 5/5 | New file containing ReLU-related activation functions extracted from cast.cu - no logic changes, just code reorganization using activation_template.h dispatchers |
| transformer_engine/common/activation/swiglu.cu | 5/5 | New file containing SwiGLU-related activation functions extracted from cast.cu - no logic changes, just code reorganization using activation_template.h dispatchers |
| transformer_engine/common/cast/cast.cu | 5/5 | Simplified to only contain cast/quantize/dequantize API entry points - all activation functions moved to separate files, all kernel implementations moved to organized subdirectories by scaling mode |
Sequence Diagram
sequenceDiagram
participant API as Public API<br/>(cast.cu, activation/*.cu)
participant Dispatch as Dispatch Layer<br/>(dispatch/*.cuh)
participant FP8 as FP8 Kernels<br/>(fp8/*.cuh)
participant MXFP8 as MXFP8 Kernels<br/>(mxfp8/*.cuh)
participant NVFP4 as NVFP4 Kernels<br/>(nvfp4/*.cuh)
Note over API,NVFP4: Before: Monolithic cast_kernels.cuh with all scaling modes mixed
API->>Dispatch: nvte_quantize() / nvte_gelu() etc.
Note over Dispatch: Determines scaling mode<br/>from tensor metadata
alt FP8 / Delayed Scaling
Dispatch->>FP8: quantize_fp8() / gated_fp8()
FP8-->>Dispatch: Result
else MXFP8 Scaling
Dispatch->>MXFP8: quantize_mxfp8() / gated_mxfp8()
MXFP8-->>Dispatch: Result
else NVFP4 Scaling
Dispatch->>NVFP4: quantize_nvfp4() / quantize_transpose_nvfp4()
NVFP4-->>Dispatch: Result
end
Dispatch-->>API: Return quantized tensor
Note over API,NVFP4: After: Organized by scaling mode in subdirectories
5 files reviewed, no comments
|
/te-ci
Greptile Overview
Greptile Summary
This PR refactors large monolithic kernel files (cast_kernels.cuh - 2188 lines, cast_gated_kernels.cuh) into a well-organized directory structure split by scaling mode and operation type.
Key Changes:
- Moved
util/cast.cu → cast/cast.cu and split util/cast_kernels.cuh into modular headers
cast/dispatch/layer withquantize.cuh,gated.cuh,dequantize.cuhdispatchers - Organized kernels by precision format:
fp8/,mxfp8/,nvfp4/subdirectories - Extracted shared utilities to
cast/core/common.cuh - Updated
activation_template.hto use new dispatcher API (simpler function calls) - Removed
cast.cufrom fast-math compilation list in CMakeLists
Benefits:
- Improved navigability: each scaling mode (FP8, MXFP8, NVFP4, block scaling) now has dedicated files
- Better maintainability: changes to one format won't affect others
- Cleaner separation of concerns: dispatch logic separated from kernel implementations
- No functional changes: existing tests pass, behavior unchanged
Confidence Score: 5/5
- This PR is safe to merge - it's a pure code organization refactoring with no functional changes
- This is a well-executed refactoring that moves code without modifying behavior. The large monolithic files (2188+ lines) are cleanly split by scaling mode (FP8, MXFP8, NVFP4) into logical subdirectories. All code is moved, not modified, minimizing risk. The new dispatcher layer provides clean separation of concerns. Build configuration updated appropriately (removed cast.cu from fast-math list). Previous code review comments on style/logic issues remain valid but are pre-existing, not introduced by this PR.
- No files require special attention - this is a straightforward refactoring
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/cast/dispatch/quantize.cuh | 4/5 | New dispatcher for quantize operations, splits logic by scaling mode (FP8, MXFP8, NVFP4, block scaling) |
| transformer_engine/common/cast/dispatch/gated.cuh | 5/5 | New dispatcher for gated activation operations, cleanly separates FWD/BWD helpers by scaling mode |
| transformer_engine/common/cast/fp8/quantize_fp8.cuh | 5/5 | FP8 quantization kernels extracted from cast_kernels.cuh, code moved without modification |
| transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 5/5 | MXFP8 quantization kernels extracted from cast_kernels.cuh, organized by scaling mode |
| transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh | 5/5 | NVFP4 quantization kernels extracted from cast_kernels.cuh, clean separation by format |
| transformer_engine/common/cast/core/common.cuh | 5/5 | Common cast utilities (reduce_dbias_kernel, helpers) extracted for shared use across formats |
| transformer_engine/common/CMakeLists.txt | 5/5 | Updated path from util/cast.cu to cast/cast.cu, removed cast.cu from fast-math compilation list |
| transformer_engine/common/activation/activation_template.h | 5/5 | Updated includes to use new dispatcher headers, simplified function calls to new helpers |
Sequence Diagram
sequenceDiagram
participant App as Application
participant Act as activation_template.h
participant Disp as cast/dispatch/
participant FP8 as cast/fp8/
participant MXFP8 as cast/mxfp8/
participant NVFP4 as cast/nvfp4/
participant Core as cast/core/
Note over App,Core: New Architecture: Organized by Scaling Mode
App->>Act: act_fn() / dact_fn()
Act->>Disp: quantize_fwd_helper() / quantize_bwd_helper()
alt NVTE_DELAYED_TENSOR_SCALING
Disp->>FP8: quantize() / dequantize()
FP8->>Core: reduce_dbias_kernel()
else NVTE_MXFP8_1D_SCALING
Disp->>MXFP8: quantize() / dequantize()
MXFP8->>Core: reduce_dbias_kernel()
else NVTE_NVFP4_1D_SCALING
Disp->>NVFP4: quantize_transpose()
NVFP4->>NVFP4: core_nvfp4.cuh helpers
else NVTE_BLOCK_SCALING_*
Disp->>Disp: quantize_transpose_*_blockwise()
end
App->>Act: gated_act_fn() / dgated_act_fn()
Act->>Disp: quantize_gated_fwd/bwd_helper()
alt NVTE_DELAYED_TENSOR_SCALING
Disp->>FP8: cast_gated_tma() / cast_gated_fwd()
else NVTE_MXFP8_1D_SCALING
Disp->>MXFP8: quantize_gated()
end
Note over Disp,NVFP4: Old: 2188 lines in cast_kernels.cuh<br/>New: Split into fp8/, mxfp8/, nvfp4/, core/
No files reviewed, no comments
* Separated gated and dequantize kernels Signed-off-by: Oleg Goncharov <[email protected]> * Separated quantize, dequantize and gated functions Signed-off-by: Oleg Goncharov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed lint issues Signed-off-by: Oleg Goncharov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed persistent lint issues Signed-off-by: Oleg Goncharov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added missing compute capability 10.0 check for Quantize FP8 TMA kernels Signed-off-by: Oleg Goncharov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed the issue which was added again by autofix Signed-off-by: Oleg Goncharov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Changed files description. Completely removed non-identity activations from the NVFP4 transpose test suite Signed-off-by: Oleg Goncharov <[email protected]> * Removed unsupported template arguments in NVFP4 quantize Signed-off-by: Oleg Goncharov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed undefined symbol error Signed-off-by: Oleg Goncharov <[email protected]> * Fixed condition Signed-off-by: Oleg Goncharov <[email protected]> * Fixed CUDA version check Signed-off-by: Oleg Goncharov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Changed arch conditions order Signed-off-by: Oleg Goncharov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Oleg Goncharov <[email protected]> * Clean up Signed-off-by: Oleg Goncharov <[email protected]> * Small fix Signed-off-by: Oleg Goncharov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Small fix Signed-off-by: Oleg Goncharov <[email protected]> * Fixes per the PR review Signed-off-by: Oleg Goncharov <[email protected]> * Fix Signed-off-by: Oleg Goncharov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Split quantize helper into two (FWD and BWD) functions Signed-off-by: Oleg Goncharov <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Moved activation functions from cast.cu. Removed cast.cu from the fast-math compilation list Signed-off-by: Oleg Goncharov <[email protected]> * Enabled fast math for activations by default Signed-off-by: Oleg Goncharov <[email protected]> * Disabled fast math for activations by default Signed-off-by: Oleg Goncharov <[email protected]> --------- Signed-off-by: Oleg Goncharov <[email protected]> Signed-off-by: Oleg Goncharov <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
Breaks up the large cast_kernels.cuh and cast_gated_kernels.cuh into smaller headers organized by scaling mode. No functional or behavior changes: code is moved, not modified. This improves structure, readability, and maintainability (easier to navigate/extend specific scaling paths). Build includes/exports updated accordingly; tests unaffected.
Fixes # (issue)
Type of change
Changes
cast_kernels.cuh and cast_gated_kernels.cuh into smaller headers organized by scaling mode.
Checklist: